{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prediction based on CaDRReS-Sc pre-trained model\n",
    "This notebook shows how to load a pre-trained CaDRReS-Sc model and predict drug response for new data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-12-04T02:37:42.695742Z",
     "start_time": "2020-12-04T02:37:41.056123Z"
    }
   },
   "outputs": [],
   "source": [
    "import sys, os, pickle\n",
    "from collections import Counter\n",
    "import importlib\n",
    "from ipywidgets import widgets\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "scriptpath = '..'\n",
    "sys.path.append(os.path.abspath(scriptpath))\n",
    "\n",
    "from cadrres_sc import pp, model, evaluation, utility"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Read pre-trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-12-04T02:37:42.699136Z",
     "start_time": "2020-12-04T02:37:42.697188Z"
    }
   },
   "outputs": [],
   "source": [
    "model_dir = '../example_result/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-12-04T02:37:42.708045Z",
     "start_time": "2020-12-04T02:37:42.700518Z"
    }
   },
   "outputs": [],
   "source": [
    "obj_function = widgets.Dropdown(options=['cadrres-wo-sample-bias', 'cadrres-wo-sample-bias-weight'], description='Objective function')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-12-04T02:37:42.713078Z",
     "start_time": "2020-12-04T02:37:42.709492Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bafc0fa907774e70b2428e886f6ce59e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Dropdown(description='Objective function', options=('cadrres-wo-sample-bias', 'cadrres-wo-sample-bias-weight'),…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
    ],
   "source": [
    "# Choose the objective function of the model you trained previously\n",
    "display(obj_function)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the pre-trained model based on your selection\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-12-04T02:37:42.717566Z",
     "start_time": "2020-12-04T02:37:42.714275Z"
    }
   },
   "outputs": [],
   "source": [
    "model_spec_name = obj_function.value\n",
    "model_file = model_dir + '{}_param_dict.pickle'.format(model_spec_name)\n",
    "\n",
    "cadrres_model = model.load_model(model_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Read test data\n",
    "Again, for this example we load the GDSC dataset.\n",
    "@TODO: GDSC dataset using only essential gene list?\n",
    "\n",
    "Note: GDSC_exp.tsv can be downloaded from https://www.dropbox.com/s/3v576mspw5yewbm/GDSC_exp.tsv?dl=0\n",
    "\n",
    "## Notes for other test data\n",
    "\n",
    "You can apply the model to other gene expression datasets. The input gene expression matrix must already be normalized, i.e. **for each sample, expression values are comparable across genes**.\n",
    "\n",
    "In this example, the gene expression matrix provided by GDSC is already normalized using RMA.\n",
    "\n",
    "For RNA-seq data, read counts should be normalized by gene length, using a normalization method such as TPM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-12-04T02:37:45.949356Z",
     "start_time": "2020-12-04T02:37:42.718689Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataframe shape: (17419, 1018) \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>906826</th>\n",
       "      <th>687983</th>\n",
       "      <th>910927</th>\n",
       "      <th>1240138</th>\n",
       "      <th>1240139</th>\n",
       "      <th>906792</th>\n",
       "      <th>910688</th>\n",
       "      <th>1240135</th>\n",
       "      <th>1290812</th>\n",
       "      <th>907045</th>\n",
       "      <th>...</th>\n",
       "      <th>753584</th>\n",
       "      <th>907044</th>\n",
       "      <th>998184</th>\n",
       "      <th>908145</th>\n",
       "      <th>1659787</th>\n",
       "      <th>1298157</th>\n",
       "      <th>1480372</th>\n",
       "      <th>1298533</th>\n",
       "      <th>930299</th>\n",
       "      <th>905954.1</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>GENE</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>A1BG</th>\n",
       "      <td>6.208447</td>\n",
       "      <td>5.025810</td>\n",
       "      <td>5.506955</td>\n",
       "      <td>4.208349</td>\n",
       "      <td>3.399366</td>\n",
       "      <td>4.917872</td>\n",
       "      <td>3.828088</td>\n",
       "      <td>5.146903</td>\n",
       "      <td>3.107543</td>\n",
       "      <td>5.062066</td>\n",
       "      <td>...</td>\n",
       "      <td>4.272172</td>\n",
       "      <td>3.435025</td>\n",
       "      <td>4.930052</td>\n",
       "      <td>2.900213</td>\n",
       "      <td>4.523712</td>\n",
       "      <td>5.074951</td>\n",
       "      <td>2.957153</td>\n",
       "      <td>3.089628</td>\n",
       "      <td>4.047364</td>\n",
       "      <td>5.329524</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>A1CF</th>\n",
       "      <td>2.981775</td>\n",
       "      <td>2.947547</td>\n",
       "      <td>2.872071</td>\n",
       "      <td>3.075478</td>\n",
       "      <td>2.853231</td>\n",
       "      <td>3.221491</td>\n",
       "      <td>2.996355</td>\n",
       "      <td>2.893977</td>\n",
       "      <td>2.755668</td>\n",
       "      <td>2.985650</td>\n",
       "      <td>...</td>\n",
       "      <td>2.941659</td>\n",
       "      <td>3.155536</td>\n",
       "      <td>2.983619</td>\n",
       "      <td>3.118312</td>\n",
       "      <td>2.975409</td>\n",
       "      <td>2.905804</td>\n",
       "      <td>2.944488</td>\n",
       "      <td>2.780003</td>\n",
       "      <td>2.870819</td>\n",
       "      <td>2.926353</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2 rows × 1018 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        906826    687983    910927   1240138   1240139    906792    910688  \\\n",
       "GENE                                                                         \n",
       "A1BG  6.208447  5.025810  5.506955  4.208349  3.399366  4.917872  3.828088   \n",
       "A1CF  2.981775  2.947547  2.872071  3.075478  2.853231  3.221491  2.996355   \n",
       "\n",
       "       1240135   1290812    907045  ...    753584    907044    998184  \\\n",
       "GENE                                ...                                 \n",
       "A1BG  5.146903  3.107543  5.062066  ...  4.272172  3.435025  4.930052   \n",
       "A1CF  2.893977  2.755668  2.985650  ...  2.941659  3.155536  2.983619   \n",
       "\n",
       "        908145   1659787   1298157   1480372   1298533    930299  905954.1  \n",
       "GENE                                                                        \n",
       "A1BG  2.900213  4.523712  5.074951  2.957153  3.089628  4.047364  5.329524  \n",
       "A1CF  3.118312  2.975409  2.905804  2.944488  2.780003  2.870819  2.926353  \n",
       "\n",
       "[2 rows x 1018 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
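    ],
   "source": [
    "# Load the GDSC expression matrix (rows: genes, columns: cell lines)\n",
    "gene_exp_df = pd.read_csv('../data/GDSC/GDSC_exp.tsv', sep='\\t', index_col=0)\n",
    "# Average rows that share the same gene name\n",
    "gene_exp_df = gene_exp_df.groupby(gene_exp_df.index).mean()\n",
    "print(\"Dataframe shape:\", gene_exp_df.shape, \"\\n\")\n",
    "gene_exp_df.head(2)"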
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculate fold-change\n",
    "We normalize baseline gene expression values for each gene by computing fold-changes relative to the median value across cell lines."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-12-04T02:37:49.916450Z",
     "start_time": "2020-12-04T02:37:45.950650Z"
    }
   },
   "outputs": [],
   "source": [
    "cell_line_log2_mean_fc_exp_df, cell_line_mean_exp_df = pp.gexp.normalize_log2_mean_fc(gene_exp_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read essential gene list\n",
    "\n",
    "You can substitute another gene list here if the model was trained on a specific set of genes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-12-04T02:37:49.922448Z",
     "start_time": "2020-12-04T02:37:49.918253Z"
    }
   },
   "outputs": [],
   "source": [
    "ess_gene_list = utility.get_gene_list('../data/essential_genes.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculate kernel features\n",
    "\n",
    "Kernel features are calculated from all given cell line samples with gene expression profiles and a list of genes (e.g. the essential gene list). This step may take longer than the previous ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2020-12-04T02:37:40.976Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calculating kernel features based on 1610 common genes\n",
      "(17419, 1018) (17419, 1018)\n",
      "100 of 1018 (8.24)s\n",
      "200 of 1018 (8.23)s\n",
      "300 of 1018 (8.25)s\n",
      "400 of 1018 (8.18)s\n",
      "500 of 1018 (8.16)s\n"
     ]
    }
   ],
   "source": [
    "test_kernel_df = pp.gexp.calculate_kernel_feature(cell_line_log2_mean_fc_exp_df, cell_line_log2_mean_fc_exp_df, ess_gene_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2020-12-04T02:37:40.976Z"
    }
   },
   "outputs": [],
   "source": [
    "print(\"Dataframe shape:\", test_kernel_df.shape, \"\\n\")\n",
    "test_kernel_df.head(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Drug response prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2020-12-04T02:37:40.977Z"
    }
   },
   "outputs": [],
   "source": [
    "print('Predicting drug response using CaDRReS: {}'.format(model_spec_name))\n",
    "pred_df, P_test_df = model.predict_from_model(cadrres_model, test_kernel_df, model_spec_name)\n",
    "print('done!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inspect the model predictions and save them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2020-12-04T02:37:40.978Z"
    }
   },
   "outputs": [],
   "source": [
    "pred_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2020-12-04T02:37:40.979Z"
    }
   },
   "outputs": [],
   "source": [
    "P_test_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2020-12-04T02:37:40.980Z"
    }
   },
   "outputs": [],
   "source": [
    "pred_file = model_dir + '{}_test_pred.csv'.format(model_spec_name)\n",
    "print('Saving ' + pred_file)\n",
    "pred_df.to_csv(pred_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "**Authors:** Chayaporn Suphavilai, Rafael Peres da Silva, Genome Institute of Singapore, Nagarajan Lab, November 2020\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reproducibility tips from https://github.com/jupyter-guide/ten-rules-jupyter"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
